{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from model import Transformer\n",
    "from config import get_config, get_weights_file_path\n",
    "from train import get_model, get_ds, greedy_decode\n",
    "import altair as alt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU\n",
    "device_name = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "device = torch.device(device_name)\n",
    "print(\"Using device:\", device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = get_config()\n",
    "train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)\n",
    "model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)\n",
    "\n",
    "# Load the pretrained weights from the checkpoint tagged \"29\".\n",
    "# map_location remaps the stored tensors onto the active device, so a checkpoint\n",
    "# saved on a GPU can still be opened on a CPU-only machine.\n",
    "# NOTE: torch.load unpickles the file -- only open checkpoints you trust.\n",
    "model_filename = get_weights_file_path(config, \"29\")\n",
    "state = torch.load(model_filename, map_location=device)\n",
    "model.load_state_dict(state['model_state_dict'])\n",
    "\n",
    "# Inference only: eval mode switches off training-only behaviour (e.g. dropout, if the\n",
    "# model uses any) so the visualised attention is not randomly perturbed.\n",
    "# The trailing semicolon keeps the module repr out of the cell output.\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_next_batch():\n",
    "    \"\"\"Fetch one validation batch, run greedy decoding on it and return it with its tokens.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (batch, encoder_input_tokens, decoder_input_tokens) where batch is the raw\n",
    "        item yielded by val_dataloader and the two token lists are the string tokens of\n",
    "        sample 0 of the encoder / decoder inputs.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If the batch holds more than one sample.\n",
    "    \"\"\"\n",
    "    # A fresh iterator is built on every call, so which batch comes back depends on\n",
    "    # how val_dataloader orders its data\n",
    "    batch = next(iter(val_dataloader))\n",
    "    encoder_input = batch[\"encoder_input\"].to(device)\n",
    "    encoder_mask = batch[\"encoder_mask\"].to(device)\n",
    "    decoder_input = batch[\"decoder_input\"].to(device)\n",
    "\n",
    "    # Validate before touching sample 0: only a single-sentence batch is supported\n",
    "    assert encoder_input.size(0) == 1, \"Batch size must be 1 for validation\"\n",
    "\n",
    "    encoder_input_tokens = [vocab_src.id_to_token(idx) for idx in encoder_input[0].cpu().numpy()]\n",
    "    decoder_input_tokens = [vocab_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()]\n",
    "\n",
    "    # The decoded output itself is discarded; the call is kept for the forward pass it runs\n",
    "    # (presumably this is what fills the attention_scores read by get_attn_map -- TODO confirm).\n",
    "    # No gradients are needed for visualisation.\n",
    "    with torch.no_grad():\n",
    "        greedy_decode(\n",
    "            model, encoder_input, encoder_mask, vocab_src, vocab_tgt, config['seq_len'], device)\n",
    "\n",
    "    return batch, encoder_input_tokens, decoder_input_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mtx2df(m, max_row, max_col, row_tokens, col_tokens):\n",
    "    \"\"\"Flatten the top-left corner of a 2-D matrix into a long-format DataFrame.\n",
    "\n",
    "    Args:\n",
    "        m: 2-D object indexable as m[r, c] and exposing .shape.\n",
    "        max_row: Exclusive upper bound on the row indices that are kept.\n",
    "        max_col: Exclusive upper bound on the column indices that are kept.\n",
    "        row_tokens: Labels for the rows; a row past the end of the list is labelled \"<blank>\".\n",
    "        col_tokens: Labels for the columns; a column past the end of the list is labelled \"<blank>\".\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame with one record per kept cell and the columns\n",
    "        row, column, value, row_token, col_token.\n",
    "    \"\"\"\n",
    "    # Clamp the ranges up front instead of visiting every cell of the full matrix\n",
    "    # and throwing almost all of them away\n",
    "    n_rows = min(m.shape[0], max_row)\n",
    "    n_cols = min(m.shape[1], max_col)\n",
    "\n",
    "    def label(idx, tokens):\n",
    "        # Index zero-padded to three digits, followed by the token (or a placeholder)\n",
    "        return \"%.3d %s\" % (idx, tokens[idx] if len(tokens) > idx else \"<blank>\")\n",
    "\n",
    "    return pd.DataFrame(\n",
    "        [\n",
    "            (r, c, float(m[r, c]), label(r, row_tokens), label(c, col_tokens))\n",
    "            for r in range(n_rows)\n",
    "            for c in range(n_cols)\n",
    "        ],\n",
    "        columns=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
    "    )\n",
    "\n",
    "def get_attn_map(attn_type: str, layer: int, head: int):\n",
    "    \"\"\"Return the attention scores stored on the model for one head of one layer.\n",
    "\n",
    "    Args:\n",
    "        attn_type: \"encoder\" (encoder self-attention), \"decoder\" (decoder self-attention)\n",
    "            or \"encoder-decoder\" (the decoder's cross-attention block).\n",
    "        layer: Index into model.encoder.layers / model.decoder.layers.\n",
    "        head: Index of the attention head to extract.\n",
    "\n",
    "    Returns:\n",
    "        The [0, head] slice of the selected block's attention_scores.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If attn_type is not one of the three supported names.\n",
    "    \"\"\"\n",
    "    if attn_type == \"encoder\":\n",
    "        block = model.encoder.layers[layer].self_attention_block\n",
    "    elif attn_type == \"decoder\":\n",
    "        block = model.decoder.layers[layer].self_attention_block\n",
    "    elif attn_type == \"encoder-decoder\":\n",
    "        block = model.decoder.layers[layer].cross_attention_block\n",
    "    else:\n",
    "        # Fail loudly instead of hitting an unbound local further down\n",
    "        raise ValueError(\n",
    "            f\"Unknown attn_type {attn_type!r}; expected 'encoder', 'decoder' or 'encoder-decoder'\"\n",
    "        )\n",
    "    # Index 0 picks the first sample -- assumes the scores are laid out as (batch, head, ...) -- TODO confirm\n",
    "    return block.attention_scores[0, head].data\n",
    "\n",
    "def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):\n",
    "    \"\"\"Build an interactive Altair heatmap of one attention head.\n",
    "\n",
    "    The scores come from get_attn_map and are cropped to at most\n",
    "    max_sentence_len x max_sentence_len cells by mtx2df; row_tokens and\n",
    "    col_tokens provide the axis labels.\n",
    "    \"\"\"\n",
    "    scores = get_attn_map(attn_type, layer, head)\n",
    "    frame = mtx2df(scores, max_sentence_len, max_sentence_len, row_tokens, col_tokens)\n",
    "\n",
    "    # Axis titles are blanked; the token labels speak for themselves\n",
    "    x_axis = alt.X(\"col_token\", axis=alt.Axis(title=\"\"))\n",
    "    y_axis = alt.Y(\"row_token\", axis=alt.Axis(title=\"\"))\n",
    "\n",
    "    heatmap = alt.Chart(data=frame).mark_rect().encode(\n",
    "        x=x_axis,\n",
    "        y=y_axis,\n",
    "        color=\"value\",\n",
    "        tooltip=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
    "    )\n",
    "    heatmap = heatmap.properties(height=400, width=400, title=f\"Layer {layer} Head {head}\")\n",
    "    return heatmap.interactive()\n",
    "\n",
    "def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):\n",
    "    \"\"\"Lay out attention heatmaps as a grid: one row per layer, one column per head.\n",
    "\n",
    "    Every cell of the grid is produced by attn_map with the same tokens and size limit.\n",
    "    \"\"\"\n",
    "    grid_rows = [\n",
    "        alt.hconcat(\n",
    "            *[attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len) for head in heads]\n",
    "        )\n",
    "        for layer in layers\n",
    "    ]\n",
    "    return alt.vconcat(*grid_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()\n",
    "print(f'Source: {batch[\"src_text\"][0]}')\n",
    "print(f'Target: {batch[\"tgt_text\"][0]}')\n",
    "\n",
    "# Number of source tokens before the first padding token. A sentence that fills the\n",
    "# whole sequence contains no \"[PAD]\" at all, so fall back to the full length instead\n",
    "# of letting list.index raise ValueError.\n",
    "sentence_len = (\n",
    "    encoder_input_tokens.index(\"[PAD]\")\n",
    "    if \"[PAD]\" in encoder_input_tokens\n",
    "    else len(encoder_input_tokens)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layers and heads to plot; these two lists are reused by the following cells\n",
    "layers = list(range(3))\n",
    "heads = list(range(8))\n",
    "\n",
    "# Encoder self-attention: source tokens label both axes\n",
    "get_all_attention_maps(\"encoder\", layers, heads, encoder_input_tokens, encoder_input_tokens, min(20, sentence_len))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoder self-attention: target tokens label both axes\n",
    "get_all_attention_maps(\n",
    "    attn_type=\"decoder\",\n",
    "    layers=layers,\n",
    "    heads=heads,\n",
    "    row_tokens=decoder_input_tokens,\n",
    "    col_tokens=decoder_input_tokens,\n",
    "    max_sentence_len=min(20, sentence_len),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder-decoder (cross) attention.\n",
    "# Rows are labelled with the target tokens and columns with the source tokens, on the\n",
    "# assumption that the stored scores are indexed (query, key) with the queries coming\n",
    "# from the decoder and the keys from the encoder output -- TODO confirm against model.py\n",
    "get_all_attention_maps(\"encoder-decoder\", layers, heads, decoder_input_tokens, encoder_input_tokens, min(20, sentence_len))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "transformer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
